{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a [PyTorch](http://pytorch.org/) implementation of CapsNet, described in the paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829) by [Sara Sabour](https://arxiv.org/find/cs/1/au:+Sabour_S/0/1/0/all/0/1), [Nicholas Frosst](https://arxiv.org/find/cs/1/au:+Frosst_N/0/1/0/all/0/1) and [Geoffrey E Hinton](https://arxiv.org/find/cs/1/au:+Hinton_G/0/1/0/all/0/1).\n",
    "\n",
    "All images and text in the following sections are extracted directly from the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "from tqdm.auto import tqdm\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training is performed on 28 x 28 MNIST images that have been shifted by up to 2 pixels in each direction with zero padding. No other data augmentation/deformation is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_SIZE = (1, 28, 28)\n",
    "transforms = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.RandomCrop(INPUT_SIZE[1:], padding=2),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset has 60K and 10K images for training and testing respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trn_dataset = torchvision.datasets.MNIST('.', train=True, download=True, transform=transforms)\n",
    "# Evaluate on unshifted images: the random shifts are described as a training-time augmentation\n",
    "tst_dataset = torchvision.datasets.MNIST('.', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
    "print('Images for training: %d' % len(trn_dataset))\n",
    "print('Images for testing: %d' % len(tst_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128 # Batch size not specified in the paper\n",
    "trn_loader = torch.utils.data.DataLoader(trn_dataset, BATCH_SIZE, shuffle=True)\n",
    "tst_loader = torch.utils.data.DataLoader(tst_dataset, BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define CapsNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conv1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conv1 has 256, 9 x 9 convolution kernels with a stride of 1 and ReLU activation. This layer converts pixel intensities to the activities of local feature detectors that are then used as inputs to the *primary* capsules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv1(torch.nn.Module):\n",
    "    def __init__(self, in_channels, out_channels=256, kernel_size=9):\n",
    "        super(Conv1, self).__init__()\n",
    "        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)\n",
    "        self.activation = torch.nn.ReLU()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.activation(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Primary Capsules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second layer (PrimaryCapsules) is a convolutional capsule layer with 32 channels of convolutional 8D capsules (*i.e.* each primary capsule contains 8 convolutional units with a $[9 \\times 9]$ kernel and a stride of 2). Each primary capsule output sees the outputs of all $[256 \\times 81]$ Conv1 units whose receptive fields overlap with the location of the center of the capsule. In total PrimaryCapsules has $[32 \\times 6 \\times 6]$ capsule outputs (each output is an 8D vector) and each capsule in the $[6 \\times 6]$ grid is sharing their weights with each other. One can see PrimaryCapsules as a Convolution layer with Eq. 1 as its block non-linearity."
   ]
  },
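  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the sizes quoted above, the spatial dimensions can be worked out with the standard output-size formula for an unpadded convolution. This is a worked calculation, assuming no padding in either layer (which matches the layer definitions in this notebook); $n$ is the input width, $k$ the kernel size and $s$ the stride.\n",
    "\n",
    "\\begin{equation*}\n",
    "n_{out} = \\left\\lfloor \\frac{n - k}{s} \\right\\rfloor + 1\n",
    "\\end{equation*}\n",
    "\n",
    "For Conv1 ($n = 28$, $k = 9$, $s = 1$) and PrimaryCapsules ($n = 20$, $k = 9$, $s = 2$):\n",
    "\n",
    "\\begin{equation*}\n",
    "\\left\\lfloor \\frac{28 - 9}{1} \\right\\rfloor + 1 = 20, \\quad \\left\\lfloor \\frac{20 - 9}{2} \\right\\rfloor + 1 = 6\n",
    "\\end{equation*}\n",
    "\n",
    "so the primary capsule layer produces $32 \\times 6 \\times 6 = 1152$ capsule outputs, each an 8D vector."
   ]
  },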
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrimaryCapsules(torch.nn.Module):\n",
    "    def __init__(self, input_shape=(256, 20, 20), capsule_dim=8,\n",
    "                 out_channels=32, kernel_size=9, stride=2):\n",
    "        super(PrimaryCapsules, self).__init__()\n",
    "        self.input_shape = input_shape\n",
    "        self.capsule_dim = capsule_dim\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_size = kernel_size\n",
    "        self.stride = stride\n",
    "        self.in_channels = self.input_shape[0]\n",
    "        \n",
    "        self.conv = torch.nn.Conv2d(\n",
    "            self.in_channels,\n",
    "            self.out_channels * self.capsule_dim,\n",
    "            self.kernel_size,\n",
    "            self.stride\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = x.permute(0, 2, 3, 1).contiguous()\n",
    "        x = x.view(-1, x.size()[1], x.size()[2], self.out_channels, self.capsule_dim)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Routing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want the length of the output vector of a capsule to represent the probability that the entity represented by the capsule is present in the current input. We therefore use a non-linear \"squashing\" function to ensure that short vectors get shrunk to almost zero length and long vectors get shrunk to a length slightly below 1. We leave it to discriminative learning to make good use of this non-linearity.\n",
    "\n",
    "\\begin{equation*}\n",
    "\\mathbf{v}_j = \\frac{||\\mathbf{s}_j||^2}{1 + ||\\mathbf{s}_j||^2} \\frac{\\mathbf{s}_j}{||\\mathbf{s}_j||}\n",
    "\\end{equation*}\n",
    "\n",
    "where $\\mathbf{v}_j$ is the vector output of capsule $j$ and $\\mathbf{s}_j$ is its total input.\n",
    "\n",
    "For all but the first layer of capsules, the total input to a capsule $\\mathbf{s}_j$ is a weighted sum over all \"prediction vectors\" $\\mathbf{\\hat u}_{j|i}$ from the capsules in the layer below and is produced by multiplying the output $\\mathbf{u}_i$ of a capsule in the layer below by a weight matrix $\\mathbf{W}_{ij}$\n",
    "\n",
    "\\begin{equation*}\n",
    "\\mathbf{s}_j = \\sum_i c_{ij} \\mathbf{\\hat u}_{j|i}, \\quad \\mathbf{\\hat u}_{j|i} = \\mathbf{W}_{ij} \\mathbf{u}_i\n",
    "\\end{equation*}\n",
    "\n",
    "where the $c_{ij}$ are coupling coefficients that are determined by the iterative dynamic routing process.\n",
    "\n",
    "The coupling coefficients between capsule $i$ and all the capsules in the layer above sum to 1 and are determined by a \"routing softmax\" whose initial logits $b_{ij}$ are the log prior probabilities that capsule $i$ should be coupled to capsule $j$.\n",
    "\n",
    "\\begin{equation*}\n",
    "c_{ij} = \\frac{\\exp(b_{ij})}{\\sum_k \\exp(b_{ik})}\n",
    "\\end{equation*}\n",
    "\n",
    "The log priors can be learned discriminatively at the same time as all the other weights. They depend on the location and type of the two capsules but not on the current input image. The initial coupling coefficients are then iteratively refined by measuring the agreement between the current output $\\mathbf{v}_j$ of each capsule, $j$, in the layer above and the prediction $\\mathbf{\\hat u}_{j|i}$ made by capsule $i$.\n",
    "\n",
    "The agreement is simply the scalar product $a_{ij} = \\mathbf{v}_j \\cdot \\mathbf{\\hat u}_{j|i}$. This agreement is treated as if it was a log likelihood and is added to the initial logit, $b_{ij}$ before computing the new values for all the coupling coefficients linking capsule $i$ to higher level capsules.\n",
    "\n",
    "In convolutional capsule layers, each capsule outputs a local grid of vectors to each type of capsule in the layer above using different transformation matrices for each member of the grid as well as for each type of capsule."
   ]
  },
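  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of the squashing function concrete, note that it keeps the direction of $\\mathbf{s}_j$ and only rescales its length, so the output length depends on the input length alone:\n",
    "\n",
    "\\begin{equation*}\n",
    "||\\mathbf{v}_j|| = \\frac{||\\mathbf{s}_j||^2}{1 + ||\\mathbf{s}_j||^2}\n",
    "\\end{equation*}\n",
    "\n",
    "A few illustrative input lengths (chosen here for the example, not taken from the paper) show the two regimes:\n",
    "\n",
    "\\begin{equation*}\n",
    "||\\mathbf{s}_j|| = 0.1 \\Rightarrow ||\\mathbf{v}_j|| = \\frac{0.01}{1.01} \\approx 0.0099, \\quad ||\\mathbf{s}_j|| = 1 \\Rightarrow ||\\mathbf{v}_j|| = 0.5, \\quad ||\\mathbf{s}_j|| = 10 \\Rightarrow ||\\mathbf{v}_j|| = \\frac{100}{101} \\approx 0.990\n",
    "\\end{equation*}\n",
    "\n",
    "Short vectors are shrunk roughly quadratically towards zero, while long vectors approach, but never reach, length 1."
   ]
  },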
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Routing(torch.nn.Module):\n",
    "    def __init__(self, caps_dim_before=8, caps_dim_after=16,\n",
    "                 n_capsules_before=(6 * 6 * 32), n_capsules_after=10):\n",
    "        super(Routing, self).__init__()\n",
    "        self.n_capsules_before = n_capsules_before\n",
    "        self.n_capsules_after = n_capsules_after\n",
    "        self.caps_dim_before = caps_dim_before\n",
    "        self.caps_dim_after = caps_dim_after\n",
    "        \n",
    "        # Parameter initialization not specified in the paper\n",
    "        n_in = self.n_capsules_before * self.caps_dim_before\n",
    "        variance = 2 / (n_in)\n",
    "        std = np.sqrt(variance)\n",
    "        self.W = torch.nn.Parameter(\n",
    "            torch.randn(\n",
    "                self.n_capsules_before,\n",
    "                self.n_capsules_after,\n",
    "                self.caps_dim_after,\n",
    "                self.caps_dim_before) * std,\n",
    "            requires_grad=True)\n",
    "    \n",
    "    # Equation (1)\n",
    "    @staticmethod\n",
    "    def squash(s, eps=1e-8):\n",
    "        s_norm = torch.norm(s, p=2, dim=-1, keepdim=True)\n",
    "        s_norm2 = torch.pow(s_norm, 2)\n",
    "        # eps (not in the paper) avoids a division by zero when s is the zero vector\n",
    "        v = (s_norm2 / (1.0 + s_norm2)) * (s / (s_norm + eps))\n",
    "        return v\n",
    "    \n",
    "    # Equation (2)\n",
    "    def affine(self, x):\n",
    "        # x: [?, 1152, 8] -> [?, 1152, 10, 8, 1], multiplied by W: [1152, 10, 16, 8]\n",
    "        x = self.W @ x.unsqueeze(2).expand(-1, -1, self.n_capsules_after, -1).unsqueeze(-1)\n",
    "        return x.squeeze(-1) # Squeeze only the last dim so a batch of size 1 keeps its batch dim\n",
    "    \n",
    "    # Equation (3)\n",
    "    @staticmethod\n",
    "    def softmax(x, dim=-1):\n",
    "        exp = torch.exp(x)\n",
    "        return exp / torch.sum(exp, dim, keepdim=True)\n",
    "    \n",
    "    # Procedure 1 - Routing algorithm.\n",
    "    def routing(self, u, r, l):\n",
    "        # Logits are created on the same device as the prediction vectors\n",
    "        b = torch.zeros(u.size()[0], l[0], l[1], device=u.device) # torch.Size([?, 1152, 10])\n",
    "        \n",
    "        for iteration in range(r):\n",
    "            c = Routing.softmax(b) # torch.Size([?, 1152, 10])\n",
    "            s = (c.unsqueeze(-1).expand(-1, -1, -1, u.size()[-1]) * u).sum(1) # torch.Size([?, 10, 16])\n",
    "            v = Routing.squash(s) # torch.Size([?, 10, 16])\n",
    "            b = b + (u * v.unsqueeze(1).expand(-1, l[0], -1, -1)).sum(-1) # Agreement update, out of place\n",
    "        return v\n",
    "    \n",
    "    def forward(self, x, n_routing_iter):\n",
    "        x = x.view((-1, self.n_capsules_before, self.caps_dim_before))\n",
    "        x = self.affine(x) # torch.Size([?, 1152, 10, 16])\n",
    "        x = self.routing(x, n_routing_iter, (self.n_capsules_before, self.n_capsules_after))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The final Layer (DigitCaps) has one 16D capsule per digit class and each of these capsules receives input from all the capsules in the layer below.\n",
    "\n",
    "We have routing only between two consecutive capsule layers (e.g. PrimaryCapsules and DigitCaps).\n",
    "Since Conv1 output is 1D, there is no orientation in its space to agree on. Therefore, no routing is used between Conv1 and PrimaryCapsules. All the routing logits ($b_{ij}$) are initialized to zero. Therefore, initially a capsule output ($\\mathbf{u}_i$) is sent to all parent capsules ($\\mathbf{v}_0...\\mathbf{v}_9$) with equal probability ($c_{ij}$)."
   ]
  },
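  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The equal-probability statement can be verified directly from the routing softmax defined earlier. With every logit $b_{ij} = 0$ and 10 digit capsules, a short worked step gives:\n",
    "\n",
    "\\begin{equation*}\n",
    "c_{ij} = \\frac{\\exp(0)}{\\sum_{k=0}^{9} \\exp(0)} = \\frac{1}{10}\n",
    "\\end{equation*}\n",
    "\n",
    "so in the first routing iteration every digit capsule receives an equally weighted sum of its prediction vectors:\n",
    "\n",
    "\\begin{equation*}\n",
    "\\mathbf{s}_j = \\frac{1}{10} \\sum_i \\mathbf{\\hat u}_{j|i}\n",
    "\\end{equation*}\n",
    "\n",
    "Later iterations move away from this uniform weighting as the agreements $a_{ij}$ are added to the logits."
   ]
  },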
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Norm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are using the length of the instantiation vector to represent the probability that a capsule’s entity exists. We would like the top-level capsule for digit class $k$ to have a long instantiation vector if and only if that digit is present in the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Norm(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Norm, self).__init__()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = torch.norm(x, p=2, dim=-1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During training, we mask out all but the activity vector of the correct digit capsule. Then we use this activity vector to reconstruct the input image. The output of the digit capsule is fed into a decoder consisting of 3 fully connected layers that model the pixel intensities (...).\n",
    "\n",
    "<img src=\"./images/reconsArch.png\" width=\"500\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(torch.nn.Module):\n",
    "    def __init__(self, in_features, out_features, output_size=INPUT_SIZE):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.decoder = self.assemble_decoder(in_features, out_features)\n",
    "        self.output_size = output_size\n",
    "    \n",
    "    def assemble_decoder(self, in_features, out_features):\n",
    "        HIDDEN_LAYER_FEATURES = [512, 1024]\n",
    "        return torch.nn.Sequential(\n",
    "            torch.nn.Linear(in_features, HIDDEN_LAYER_FEATURES[0]),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(HIDDEN_LAYER_FEATURES[0], HIDDEN_LAYER_FEATURES[1]),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(HIDDEN_LAYER_FEATURES[1], out_features),\n",
    "            torch.nn.Sigmoid(),\n",
    "        )\n",
    "    \n",
    "    def forward(self, x, y):\n",
    "        # Mask: keep only the activity vector of the target digit capsule, staying on x's device\n",
    "        x = x[torch.arange(x.size(0), device=x.device), y, :]\n",
    "        x = self.decoder(x)\n",
    "        x = x.view(*((-1,) + self.output_size))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CapsNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The architecture is shallow with only two convolutional layers and one fully connected layer.\n",
    "\n",
    "<img src=\"./images/capsulearch.png\" width=\"700\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CapsNet(torch.nn.Module):\n",
    "    def __init__(self, input_shape=INPUT_SIZE, n_routing_iter=3, use_reconstruction=True):\n",
    "        super(CapsNet, self).__init__()\n",
    "        assert len(input_shape) == 3\n",
    "        \n",
    "        self.input_shape = input_shape\n",
    "        self.n_routing_iter = n_routing_iter\n",
    "        self.use_reconstruction = use_reconstruction\n",
    "        \n",
    "        self.conv1 = Conv1(input_shape[0], 256, 9)\n",
    "        self.primary_capsules = PrimaryCapsules(\n",
    "            input_shape=(256, 20, 20),\n",
    "            capsule_dim=8,\n",
    "            out_channels=32,\n",
    "            kernel_size=9,\n",
    "            stride=2\n",
    "        )\n",
    "        self.routing = Routing(\n",
    "            caps_dim_before=8,\n",
    "            caps_dim_after=16,\n",
    "            n_capsules_before=6 * 6 * 32,\n",
    "            n_capsules_after=10\n",
    "        )\n",
    "        self.norm = Norm()\n",
    "        \n",
    "        if (self.use_reconstruction):\n",
    "            self.decoder = Decoder(16, int(np.prod(input_shape)))\n",
    "    \n",
    "    def n_parameters(self):\n",
    "        return np.sum([np.prod(x.size()) for x in self.parameters()])\n",
    "    \n",
    "    def forward(self, x, y=None):\n",
    "        conv1 = self.conv1(x)\n",
    "        primary_capsules = self.primary_capsules(conv1)\n",
    "        digit_caps = self.routing(primary_capsules, self.n_routing_iter)\n",
    "        scores = self.norm(digit_caps)\n",
    "        \n",
    "        if (self.use_reconstruction and y is not None):\n",
    "            reconstruction = self.decoder(digit_caps, y).view((-1,) + self.input_shape)\n",
    "            return scores, reconstruction\n",
    "        \n",
    "        return scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Loss Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Margin Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To allow for multiple digits, we use a separate margin loss, $L_k$ for each digit capsule, $k$:\n",
    "\n",
    "\\begin{equation*}\n",
    "L_k = T_k \\max(0, m^+ - ||\\mathbf{v}_k||)^2 + \\lambda (1 - T_k) \\max(0, ||\\mathbf{v}_k|| - m^-)^2\n",
    "\\end{equation*}\n",
    "\n",
    "where $T_k = 1$ iff a digit of class $k$ is present and $m^+ = 0.9$ and $m^- = 0.1$. The $\\lambda$ down-weighting of the loss for absent digit classes stops the initial learning from shrinking the lengths of the activity vectors of all the digit capsules. We use $\\lambda = 0.5$. The total loss is simply the sum of the losses of all digit capsules."
   ]
  },
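  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked example may help to read the margin loss. The capsule lengths below are made up for illustration and are not results from the paper or from this notebook. Suppose the correct digit capsule has length $0.8$, one wrong capsule has length $0.3$, and the remaining eight capsules have lengths of at most $0.1$.\n",
    "\n",
    "\\begin{equation*}\n",
    "L_{correct} = \\max(0, 0.9 - 0.8)^2 = 0.01, \\quad L_{wrong} = 0.5 \\cdot \\max(0, 0.3 - 0.1)^2 = 0.02\n",
    "\\end{equation*}\n",
    "\n",
    "The other eight capsules fall inside the $m^-$ margin and contribute $0$, so the total margin loss for this example is $0.01 + 0.02 = 0.03$."
   ]
  },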
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_categorical(y, num_classes):\n",
    "    \"\"\" 1-hot encodes a tensor of class indices, on the same device as y \"\"\"\n",
    "    return torch.eye(num_classes, device=y.device)[y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MarginLoss(torch.nn.Module):\n",
    "    def __init__(self, m_pos=0.9, m_neg=0.1, lamb=0.5):\n",
    "        super(MarginLoss, self).__init__()\n",
    "        self.m_pos = m_pos\n",
    "        self.m_neg = m_neg\n",
    "        self.lamb = lamb\n",
    "    \n",
    "    # Equation (4)\n",
    "    def forward(self, scores, y):\n",
    "        y = to_categorical(y, scores.size(-1)) # Number of classes taken from the scores\n",
    "        \n",
    "        Tc = y.float()\n",
    "        loss_pos = torch.pow(torch.clamp(self.m_pos - scores, min=0), 2)\n",
    "        loss_neg = torch.pow(torch.clamp(scores - self.m_neg, min=0), 2)\n",
    "        loss = Tc * loss_pos + self.lamb * (1 - Tc) * loss_neg\n",
    "        loss = loss.sum(-1)\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reconstruction Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use an additional reconstruction loss to encourage the digit capsules to encode the instantiation parameters of the input digit. (...) We minimize the sum of squared differences between the outputs of the logistic units and the pixel intensities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SumSquaredDifferencesLoss(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SumSquaredDifferencesLoss, self).__init__()\n",
    "    \n",
    "    def forward(self, x_reconstruction, x):\n",
    "        loss = torch.pow(x - x_reconstruction, 2).sum(-1).sum(-1)\n",
    "        return loss.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Total Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We scale down this reconstruction loss by $0.0005$ so that it does not dominate the margin loss during training."
   ]
  },
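  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the combined objective is the margin loss plus the scaled reconstruction term. This is a restatement of the text above; the symbols $x_p$ (pixel intensity) and $\\hat x_p$ (decoder output for pixel $p$) are introduced here for illustration only.\n",
    "\n",
    "\\begin{equation*}\n",
    "L = \\sum_k L_k + 0.0005 \\sum_p (x_p - \\hat x_p)^2\n",
    "\\end{equation*}\n",
    "\n",
    "For a $28 \\times 28$ image with intensities in $[0, 1]$, the sum of squared differences is at most $784$, so the scaled reconstruction term is bounded by\n",
    "\n",
    "\\begin{equation*}\n",
    "0.0005 \\times 784 = 0.392\n",
    "\\end{equation*}"
   ]
  },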
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CapsNetLoss(torch.nn.Module):\n",
    "    def __init__(self, reconstruction_loss_scale=0.0005):\n",
    "        super(CapsNetLoss, self).__init__()\n",
    "        self.digit_existence_criterion = MarginLoss()\n",
    "        self.digit_reconstruction_criterion = SumSquaredDifferencesLoss()\n",
    "        self.reconstruction_loss_scale = reconstruction_loss_scale\n",
    "    \n",
    "    def forward(self, x, y, x_reconstruction, scores):\n",
    "        # Use the scores argument; the previous code read an undefined local name (y_pred)\n",
    "        margin_loss = self.digit_existence_criterion(scores, y)\n",
    "        reconstruction_loss = self.reconstruction_loss_scale *\\\n",
    "                              self.digit_reconstruction_criterion(x_reconstruction, x)\n",
    "        loss = margin_loss + reconstruction_loss\n",
    "        return loss, margin_loss, reconstruction_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CapsNet().cuda()\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnetwork."
   ]
  },
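  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These figures can be checked by hand for the layers as they are defined in this notebook (weights plus biases; the routing layer here has no bias term):\n",
    "\n",
    "\\begin{equation*}\n",
    "\\underbrace{256 \\cdot 81 + 256}_{\\text{Conv1}} + \\underbrace{256 \\cdot 256 \\cdot 81 + 256}_{\\text{PrimaryCapsules}} + \\underbrace{1152 \\cdot 10 \\cdot 16 \\cdot 8}_{\\mathbf{W}_{ij}} = 20{,}992 + 5{,}308{,}672 + 1{,}474{,}560 = 6{,}804{,}224\n",
    "\\end{equation*}\n",
    "\n",
    "which is about 6.8M. The decoder as instantiated here takes a single 16D capsule as input:\n",
    "\n",
    "\\begin{equation*}\n",
    "(16 \\cdot 512 + 512) + (512 \\cdot 1024 + 1024) + (1024 \\cdot 784 + 784) = 8{,}704 + 525{,}312 + 803{,}600 = 1{,}337{,}616\n",
    "\\end{equation*}\n",
    "\n",
    "for a total of $8{,}141{,}840$, or roughly 8.1M. A decoder fed with all ten masked 16D capsules (160 inputs) would have $144 \\cdot 512 = 73{,}728$ more weights, giving $8{,}215{,}568 \\approx 8.2$M. That may explain the difference from the figure quoted above, but it is an assumption and not something this notebook confirms."
   ]
  },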
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Number of Parameters: %d' % model.n_parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = CapsNetLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(...) we use the Adam optimizer with its TensorFlow default parameters, including the exponentially decaying learning rate, to minimize the sum of the margin losses in Eq. 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def exponential_decay(optimizer, learning_rate, global_step, decay_steps, decay_rate, staircase=False):\n",
    "    if (staircase):\n",
    "        decayed_learning_rate = learning_rate * np.power(decay_rate, global_step // decay_steps)\n",
    "    else:\n",
    "        decayed_learning_rate = learning_rate * np.power(decay_rate, global_step / decay_steps)\n",
    "        \n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = decayed_learning_rate\n",
    "    \n",
    "    return optimizer"
   ]
  },
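  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper above follows the usual exponential decay rule, where $\\eta_0$ is the initial learning rate, $\\gamma$ the decay rate, $t$ the step counter and $T$ the number of decay steps (symbols chosen here for illustration):\n",
    "\n",
    "\\begin{equation*}\n",
    "\\eta_t = \\eta_0 \\, \\gamma^{t / T}\n",
    "\\end{equation*}\n",
    "\n",
    "With `staircase=True` the exponent becomes $\\lfloor t / T \\rfloor$. As a worked instance, using the values passed in the training loop of this notebook ($\\eta_0 = 0.001$, $\\gamma = 0.9$, $T = 1$, with $t$ counted in epochs; the paper does not specify these settings), the learning rate after 10 epochs is\n",
    "\n",
    "\\begin{equation*}\n",
    "\\eta_{10} = 0.001 \\times 0.9^{10} \\approx 3.49 \\times 10^{-4}\n",
    "\\end{equation*}"
   ]
  },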
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE = 0.001\n",
    "optimizer = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=LEARNING_RATE,\n",
    "    betas=(0.9, 0.999),\n",
    "    eps=1e-08\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer, path=None):\n",
    "    if (path is None):\n",
    "        path = 'checkpoint-%f-%04d.pth' % (test_accuracy, epoch)\n",
    "    state = {\n",
    "        'epoch': epoch,\n",
    "        'train_accuracy': train_accuracy,\n",
    "        'test_accuracy': test_accuracy,\n",
    "        'model_state_dict': model.state_dict(),\n",
    "        'optimizer_state_dict': optimizer.state_dict(),\n",
    "    }\n",
    "    torch.save(state, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_example(model, x, y, x_reconstruction, y_pred):\n",
    "    x = x.squeeze().cpu().data.numpy()\n",
    "    y = y.cpu().data.numpy()\n",
    "    x_reconstruction = x_reconstruction.squeeze().cpu().data.numpy()\n",
    "    _, y_pred = torch.max(y_pred, -1)\n",
    "    y_pred = y_pred.cpu().data.numpy()\n",
    "    \n",
    "    fig, ax = plt.subplots(1, 2)\n",
    "    ax[0].imshow(x, cmap='Greys')\n",
    "    ax[0].set_title('Input: %d' % y)\n",
    "    ax[1].imshow(x_reconstruction, cmap='Greys')\n",
    "    ax[1].set_title('Output: %d' % y_pred)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad() # Evaluation needs no gradients; this avoids building the autograd graph\n",
    "def test(model, loader):\n",
    "    metrics = defaultdict(lambda:list())\n",
    "    for batch_id, (x, y) in tqdm(enumerate(loader), total=len(loader)):\n",
    "        x = Variable(x).float().cuda()\n",
    "        y = Variable(y).cuda()\n",
    "        y_pred, x_reconstruction = model(x, y)\n",
    "        _, y_pred = torch.max(y_pred, -1)\n",
    "        metrics['accuracy'].append((y_pred == y).cpu().data.numpy())\n",
    "    metrics['accuracy'] = np.concatenate(metrics['accuracy']).mean()\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global_epoch = 0\n",
    "global_step = 0\n",
    "best_tst_accuracy = 0.0\n",
    "history = defaultdict(lambda:list())\n",
    "COMPUTE_TRN_METRICS = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 1500 # Number of epochs not specified in the paper\n",
    "for epoch in range(n_epochs):\n",
    "    print('Epoch %d (%d/%d):' % (global_epoch + 1, epoch + 1, n_epochs))\n",
    "    \n",
    "    for batch_id, (x, y) in tqdm(enumerate(trn_loader), total=len(trn_loader)):\n",
    "        optimizer = exponential_decay(optimizer, LEARNING_RATE, global_epoch, 1, 0.90) # Configurations not specified in the paper\n",
    "        \n",
    "        x = Variable(x).float().cuda()\n",
    "        y = Variable(y).cuda()\n",
    "        \n",
    "        y_pred, x_reconstruction = model(x, y)\n",
    "        loss, margin_loss, reconstruction_loss = criterion(x, y, x_reconstruction, y_pred.cuda())\n",
    "        \n",
    "        history['margin_loss'].append(margin_loss.cpu().data.numpy())\n",
    "        history['reconstruction_loss'].append(reconstruction_loss.cpu().data.numpy())\n",
    "        history['loss'].append(loss.cpu().data.numpy())\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        global_step += 1\n",
    "\n",
    "    trn_metrics = test(model, trn_loader) if COMPUTE_TRN_METRICS else None\n",
    "    tst_metrics = test(model, tst_loader)\n",
    "    \n",
    "    print('Margin Loss: %f' % history['margin_loss'][-1])\n",
    "    print('Reconstruction Loss: %f' % history['reconstruction_loss'][-1])\n",
    "    print('Loss: %f' % history['loss'][-1])\n",
    "    print('Train Accuracy: %f' % (trn_metrics['accuracy'] if COMPUTE_TRN_METRICS else 0.0))\n",
    "    print('Test Accuracy: %f' % tst_metrics['accuracy'])\n",
    "    \n",
    "    print('Example:')\n",
    "    idx = np.random.randint(0, len(x))\n",
    "    show_example(model, x[idx], y[idx], x_reconstruction[idx], y_pred[idx])\n",
    "    \n",
    "    if (tst_metrics['accuracy'] >= best_tst_accuracy):\n",
    "        best_tst_accuracy = tst_metrics['accuracy']\n",
    "        save_checkpoint(\n",
    "            global_epoch + 1,\n",
    "            trn_metrics['accuracy'] if COMPUTE_TRN_METRICS else 0.0,\n",
    "            tst_metrics['accuracy'],\n",
    "            model,\n",
    "            optimizer\n",
    "        )\n",
    "    global_epoch += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss Curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_avg_curve(y, n_points_avg):\n",
    "    avg_kernel = np.ones((n_points_avg,)) / n_points_avg\n",
    "    rolling_mean = np.convolve(y, avg_kernel, mode='valid')\n",
    "    return rolling_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_points_avg = 10\n",
    "n_points_plot = 1000\n",
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "curve = np.asarray(history['loss'])[-n_points_plot:]\n",
    "avg_curve = compute_avg_curve(curve, n_points_avg)\n",
    "plt.plot(avg_curve, '-g')\n",
    "\n",
    "curve = np.asarray(history['margin_loss'])[-n_points_plot:]\n",
    "avg_curve = compute_avg_curve(curve, n_points_avg)\n",
    "plt.plot(avg_curve, '-b')\n",
    "\n",
    "curve = np.asarray(history['reconstruction_loss'])[-n_points_plot:]\n",
    "avg_curve = compute_avg_curve(curve, n_points_avg)\n",
    "plt.plot(avg_curve, '-r')\n",
    "\n",
    "plt.legend(['Total Loss', 'Margin Loss', 'Reconstruction Loss'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Done!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "toc": {
   "nav_menu": {
    "height": "177px",
    "width": "219px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {
    "height": "659px",
    "left": "0px",
    "right": "1007.8px",
    "top": "133px",
    "width": "241px"
   },
   "toc_section_display": "block",
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
